# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
import os
import time
from typing import Any, List, Optional

import torch
from accelerate import Accelerator
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import EvaluationMode
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats
from torch.utils.data import DataLoader, Dataset

from .utils import seed_all_random_engines

logger = logging.getLogger(__name__)


class TrainingLoopBase(ReplaceableBase):
    """
    Abstract interface for the outer training loop of an Implicitron experiment.

    Concrete implementations provide `run`, which drives training (and any
    validation/evaluation) of a model, and `load_stats`, which builds the
    Stats object that `run` logs into.

    Members:
        evaluator: An optional EvaluatorBase instance, used to evaluate
            training results.
        evaluator_class_type: Name of the registered EvaluatorBase
            implementation to use for `evaluator` (config-system convention);
            defaults to "ImplicitronEvaluator".
    """

    # pyre-fixme[13]: Attribute `evaluator` is never initialized.
    evaluator: Optional[EvaluatorBase]
    evaluator_class_type: Optional[str] = "ImplicitronEvaluator"

    def run(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        test_loader: Optional[DataLoader],
        train_dataset: Dataset,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        **kwargs,
    ) -> None:
        """
        Train `model` using the given loaders, optimizer and scheduler.

        The base class provides no implementation; subclasses must override.
        """
        raise NotImplementedError

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        """
        Create (or restore from `exp_dir`) the Stats object tracking `log_vars`.

        The base class provides no implementation; subclasses must override.
        """
        raise NotImplementedError


@registry.register
class ImplicitronTrainingLoop(TrainingLoopBase):
    """
    Default Implicitron training loop: trains a model epoch by epoch and
    optionally validates, evaluates, visualizes and checkpoints along the way.

    Members:
        eval_only: If True, only run evaluation using the test dataloader.
        max_epochs: Train up to (but excluding) this epoch index. Note that if
            the model was loaded from a checkpoint, we will restart training at
            epoch (checkpoint_epoch + 1) and run for
            (max_epochs - checkpoint_epoch - 1) epochs.
        store_checkpoints: If True, store model and optimizer state checkpoints.
        store_checkpoints_purge: If > 0, after saving the checkpoint of epoch
            `e`, purge (via `model_io.purge_epoch`) all epochs
            `< e - store_checkpoints_purge`, i.e. keep only the
            `store_checkpoints_purge` most recent earlier checkpoints.
            If <= 0, nothing is purged.
        test_interval: Evaluate on a test dataloader each `test_interval` epochs.
            Disabled if <= 0.
        test_when_finished: If True, evaluate on a test dataloader when training
            completes.
        validation_interval: Validate each `validation_interval` epochs.
            Disabled if <= 0.
        clip_grad: Optionally clip the gradient norms.
            If set to a value <=0.0, no clipping
        metric_print_interval: The batch interval at which the stats should be
            logged. If <= 0, stats are only logged after the last batch.
        visualize_interval: The batch interval at which the visualizations
            should be plotted. Disabled if <= 0.
        visdom_env: The name of the Visdom environment to use for plotting.
        visdom_port: The Visdom port. Defaults to the VISDOM_PORT environment
            variable, or 8097 if unset.
        visdom_server: Address of the Visdom server.
    """

    # Parameters of the outer training loop.
    eval_only: bool = False
    max_epochs: int = 1000
    store_checkpoints: bool = True
    store_checkpoints_purge: int = 1
    test_interval: int = -1
    test_when_finished: bool = False
    validation_interval: int = 1

    # Gradient clipping.
    clip_grad: float = 0.0

    # Visualization/logging parameters.
    metric_print_interval: int = 5
    visualize_interval: int = 1000
    visdom_env: str = ""
    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
    visdom_server: str = "http://127.0.0.1"

    def __post_init__(self):
        run_auto_creation(self)

    # pyre-fixme[14]: `run` overrides method defined in `TrainingLoopBase`
    #  inconsistently.
    def run(
        self,
        *,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        test_loader: Optional[DataLoader],
        train_dataset: Dataset,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        accelerator: Optional[Accelerator],
        device: torch.device,
        exp_dir: str,
        stats: Stats,
        seed: int,
        **kwargs,
    ):
        """
        Entry point to run the training and validation loops
        based on the specified config file.

        Args:
            train_loader: Dataloader for the training epochs.
            val_loader: Optional dataloader for validation epochs.
            test_loader: Optional dataloader for evaluation with `self.evaluator`.
            train_dataset: The training dataset (not used by this loop).
            model: The model to train.
            optimizer: The optimizer stepping the model parameters.
            scheduler: LR scheduler; must expose `last_epoch`, `get_last_lr()`
                and `step()`, and its `last_epoch` must equal
                `stats.epoch + 1` on entry.
            accelerator: Optional Accelerator used for backward passes,
                model unwrapping and main-process gating.
            device: Device each batch is moved to.
            exp_dir: Experiment directory used for checkpoints and eval output.
            stats: Stats object; training resumes at `stats.epoch + 1`.
            seed: Base seed; epoch `e` re-seeds all RNGs with `seed + e`.

        Raises:
            ValueError: If evaluation is requested (`eval_only` or
                `test_when_finished`) but `test_loader` is None, or if an
                evaluation has to run while `self.evaluator` is None.
        """
        start_epoch = stats.epoch + 1
        # The scheduler and the stats must agree on where training resumes.
        assert scheduler.last_epoch == start_epoch, (
            f"inconsistent scheduler/stats: scheduler.last_epoch="
            f"{scheduler.last_epoch}, expected {start_epoch}"
        )

        # only run evaluation on the test dataloader
        if self.eval_only:
            if test_loader is None:
                raise ValueError(
                    "Cannot evaluate and dump results to json, no test data provided."
                )
            self._get_evaluator().run(
                dataloader=test_loader,
                device=device,
                dump_to_json=True,
                epoch=stats.epoch,
                exp_dir=exp_dir,
                model=model,
            )
            return

        # loop through epochs
        for epoch in range(start_epoch, self.max_epochs):
            # automatic new_epoch and plotting of stats at every epoch start
            with stats:
                # Make sure to re-seed random generators to ensure reproducibility
                # even after restart.
                seed_all_random_engines(seed + epoch)

                cur_lr = float(scheduler.get_last_lr()[-1])
                logger.debug(f"scheduler lr = {cur_lr:1.2e}")

                # train loop
                self._training_or_validation_epoch(
                    accelerator=accelerator,
                    device=device,
                    epoch=epoch,
                    loader=train_loader,
                    model=model,
                    optimizer=optimizer,
                    stats=stats,
                    validation=False,
                )

                # val loop (optional); a non-positive interval disables it
                # instead of raising ZeroDivisionError in the modulo.
                if (
                    val_loader is not None
                    and self.validation_interval > 0
                    and epoch % self.validation_interval == 0
                ):
                    self._training_or_validation_epoch(
                        accelerator=accelerator,
                        device=device,
                        epoch=epoch,
                        loader=val_loader,
                        model=model,
                        optimizer=optimizer,
                        stats=stats,
                        validation=True,
                    )

                # eval loop (optional)
                if (
                    test_loader is not None
                    and self.test_interval > 0
                    and epoch % self.test_interval == 0
                ):
                    self._get_evaluator().run(
                        device=device,
                        dataloader=test_loader,
                        model=model,
                    )

                assert stats.epoch == epoch, "inconsistent stats!"
                self._checkpoint(accelerator, epoch, exp_dir, model, optimizer, stats)

                scheduler.step()
                new_lr = float(scheduler.get_last_lr()[-1])
                if new_lr != cur_lr:
                    logger.info(f"LR change! {cur_lr} -> {new_lr}")

        if self.test_when_finished:
            if test_loader is None:
                raise ValueError(
                    "Cannot evaluate and dump results to json, no test data provided."
                )
            self._get_evaluator().run(
                device=device,
                dump_to_json=True,
                epoch=stats.epoch,
                exp_dir=exp_dir,
                dataloader=test_loader,
                model=model,
            )

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        """
        Load Stats that correspond to the model's log_vars and resume_epoch.

        Args:
            log_vars: A list of variable names to log. Should be a subset of the
                `preds` returned by the forward function of the corresponding
                ImplicitronModelBase instance.
            exp_dir: Root experiment directory.
            resume: If False, do not load stats from any checkpoint; instead,
                create a fresh stats object.
            resume_epoch: If > 0, resume from the checkpoint of this epoch
                (which must exist). Otherwise resume from the last checkpoint
                found in `exp_dir`, if any.

        Returns:
            stats: The stats structure (optionally loaded from checkpoint).

        Raises:
            FileNotFoundError: If `resume` is True, `resume_epoch > 0` and the
                checkpoint file for that epoch does not exist.
        """
        visdom_env_charts = (
            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
        )
        plot_file = os.path.join(exp_dir, "train_stats.pdf")

        # Init the stats struct
        stats = Stats(
            # log_vars should be a list, but OmegaConf might load them as ListConfig
            list(log_vars),
            plot_file=plot_file,
            visdom_env=visdom_env_charts,
            visdom_server=self.visdom_server,
            visdom_port=self.visdom_port,
        )

        if not resume:
            logger.info("Clearing stats")
            return stats

        if resume_epoch > 0:
            model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
            if not os.path.isfile(model_path):
                raise FileNotFoundError(
                    f"Cannot find stats from epoch {resume_epoch}."
                )
        else:
            model_path = model_io.find_last_checkpoint(exp_dir)

        if model_path is None:
            # No checkpoint to resume from: keep the fresh stats.
            return stats

        stats_path = model_io.get_stats_path(model_path)
        stats_load = model_io.load_stats(stats_path)

        if stats_load is None:
            logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
            last_epoch = model_io.parse_epoch_from_model_path(model_path)
            logger.info(f"Estimated resume epoch = {last_epoch}")

            # Advance the fresh stats to the checkpoint's epoch so that training
            # resumes at the right place.
            for _ in range(last_epoch + 1):
                stats.new_epoch()
            assert last_epoch == stats.epoch
        else:
            logger.info(f"Found previous stats in {stats_path} -> resuming.")
            stats = stats_load

        # Update stats properties in case they were reset on load
        stats.visdom_env = visdom_env_charts
        stats.visdom_server = self.visdom_server
        stats.visdom_port = self.visdom_port
        stats.plot_file = plot_file
        # Same ListConfig -> list conversion as for the constructor above.
        stats.synchronize_logged_vars(list(log_vars))

        return stats

    def _training_or_validation_epoch(
        self,
        epoch: int,
        loader: DataLoader,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        stats: Stats,
        validation: bool,
        *,
        accelerator: Optional[Accelerator],
        bp_var: str = "objective",
        device: torch.device,
        **kwargs,
    ) -> None:
        """
        This is the main loop for training and evaluation including:
        model forward pass, loss computation, backward pass and visualization.

        Args:
            epoch: The index of the current epoch
            loader: The dataloader to use for the loop
            model: The model module optionally loaded from checkpoint
            optimizer: The optimizer module optionally loaded from checkpoint
            stats: The stats struct, also optionally loaded from checkpoint
            validation: If true, run the loop with the model in eval mode
                and skip the backward pass
            accelerator: An optional Accelerator instance.
            bp_var: The name of the key in the model output `preds` dict which
                should be used as the loss for the backward pass.
            device: The device on which to run the model.
        """

        if validation:
            model.eval()
            trainmode = "val"
        else:
            model.train()
            trainmode = "train"

        t_start = time.time()

        # get the visdom env name
        visdom_env_imgs = stats.visdom_env + "_images_" + trainmode
        viz = vis_utils.get_visdom_connection(
            server=stats.visdom_server,
            port=stats.visdom_port,
        )

        # Iterate through the batches
        n_batches = len(loader)
        for it, net_input in enumerate(loader):
            last_iter = it == n_batches - 1

            # move to gpu where possible (in place)
            net_input = net_input.to(device)

            # run the forward pass
            if not validation:
                optimizer.zero_grad()
                preds = model(
                    **{**net_input, "evaluation_mode": EvaluationMode.TRAINING}
                )
            else:
                with torch.no_grad():
                    preds = model(
                        **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
                    )

            # make sure we dont overwrite something
            clashing_keys = [k for k in net_input.keys() if k in preds]
            assert not clashing_keys, f"preds overwrite input keys: {clashing_keys}"
            # merge everything into one big dict
            preds.update(net_input)

            # update the stats logger
            stats.update(preds, time_start=t_start, stat_set=trainmode)
            # pyre-ignore [16]
            assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

            # print textual status update; a non-positive interval only prints
            # after the last batch instead of raising ZeroDivisionError.
            if last_iter or (
                self.metric_print_interval > 0
                and it % self.metric_print_interval == 0
            ):
                std_out = stats.get_status_string(stat_set=trainmode, max_it=n_batches)
                logger.info(std_out)

            # visualize results
            if (
                (accelerator is None or accelerator.is_local_main_process)
                and self.visualize_interval > 0
                and it % self.visualize_interval == 0
            ):
                prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
                if hasattr(model, "visualize"):
                    # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
                    model.visualize(
                        viz,
                        visdom_env_imgs,
                        preds,
                        prefix,
                    )

            # optimizer step
            if not validation:
                loss = preds[bp_var]
                assert torch.isfinite(loss).all(), "Non-finite loss!"
                # backprop
                if accelerator is None:
                    loss.backward()
                else:
                    accelerator.backward(loss)
                if self.clip_grad > 0.0:
                    # Optionally clip the gradient norms (in place).
                    # `clip_grad_norm` (no underscore) is deprecated.
                    total_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), self.clip_grad
                    )
                    if total_norm > self.clip_grad:
                        logger.debug(
                            "Clipping gradient: %s with coef %s.",
                            total_norm,
                            self.clip_grad / float(total_norm),
                        )

                optimizer.step()

    def _checkpoint(
        self,
        accelerator: Optional[Accelerator],
        epoch: int,
        exp_dir: str,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        stats: Stats,
    ):
        """
        Save a model, its optimizer and its corresponding Stats object to the
        checkpoint file of `epoch`, if `self.store_checkpoints` is True (and,
        with an accelerator, only on the local main process). In addition, if
        `self.store_checkpoints_purge` is > 0, first purge the checkpoints of
        all epochs `< epoch - self.store_checkpoints_purge`.
        """
        if self.store_checkpoints and (
            accelerator is None or accelerator.is_local_main_process
        ):
            if self.store_checkpoints_purge > 0:
                for prev_epoch in range(epoch - self.store_checkpoints_purge):
                    model_io.purge_epoch(exp_dir, prev_epoch)
            outfile = model_io.get_checkpoint(exp_dir, epoch)
            # Save the bare model rather than the accelerator's wrapper.
            unwrapped_model = (
                model if accelerator is None else accelerator.unwrap_model(model)
            )
            model_io.safe_save_model(
                unwrapped_model, stats, outfile, optimizer=optimizer
            )

    def _get_evaluator(self) -> EvaluatorBase:
        """
        Return `self.evaluator`, raising a descriptive ValueError if it is None
        rather than failing with an AttributeError on `None.run`.
        """
        evaluator = self.evaluator
        if evaluator is None:
            raise ValueError(
                "Evaluation was requested but no evaluator is configured "
                "(self.evaluator is None)."
            )
        return evaluator
